#!/usr/bin/env python3

import marimo

# Version stamp written by marimo when it last saved this file.
__generated_with = "0.13.15"
# The marimo app; every cell and `@app.function` below is registered on it.
app = marimo.App(width="medium")

with app.setup:
    # Setup cell: names defined here are shared by all cells and
    # `@app.function`s in this notebook.
    import marimo as mo
    import marimo._schemas.session as s

    import marimo._schemas.notebook as n
    import typing as t

    # Types converted (in list order) into the "session" schema,
    # written to marimo/_schemas/generated/session.yaml.
    SESSION_MESSAGES = [
        # Session
        s.TimeMetadata,
        s.StreamOutput,
        s.StreamMediaOutput,
        s.ErrorOutput,
        s.DataOutput,
        s.Cell,
        s.NotebookSessionMetadata,
        s.NotebookSessionV1,
    ]
    # Types converted (in list order) into the "notebook" schema,
    # written to marimo/_schemas/generated/notebook.yaml.
    NOTEBOOK_MESSAGES = [
        # Notebook
        n.NotebookCell,
        n.NotebookCellConfig,
        n.NotebookMetadata,
        n.NotebookV1,
    ]


@app.cell(hide_code=True)
def _():
    # Introductory text describing what this notebook does.
    mo.md(
        r"""This marimo notebook generates the OpenAPI schema for the `TypedDict`s defined in `marimo._schemas`"""
    )
    return


@app.cell(hide_code=True)
def _():
    # Button that, once clicked, sets `generate_schema_button.value`, which the
    # write-trigger cell checks before writing the schema files.
    generate_schema_button = mo.ui.run_button(label="Write schema")
    # Last expression: marimo displays the button as this cell's output.
    generate_schema_button
    return (generate_schema_button,)


@app.function
def generate_schema(name: str):
    """Render one schema family as an OpenAPI YAML document.

    Args:
        name: Which schema to render; must be ``"session"`` or ``"notebook"``.

    Returns:
        The YAML text, prefixed with a "generated by" comment header.

    Raises:
        ValueError: If ``name`` is not a known schema name.
    """
    import yaml

    messages_by_name = {
        "session": SESSION_MESSAGES,
        "notebook": NOTEBOOK_MESSAGES,
    }
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if name not in messages_by_name:
        raise ValueError(
            f"Invalid schema name {name!r}; must be 'session' or 'notebook'"
        )
    messages = messages_by_name[name]

    header = "# This file is generated by scripts/generate_schemas.py\n"

    print(f"Writing schema to marimo/_schemas/generated/{name}.yaml")
    schema = build_openapi_schema(messages)
    # Title each document after its own schema family; otherwise the notebook
    # schema would be labelled "marimo_session" as well.
    schema["info"]["title"] = f"marimo_{name}"
    return header + yaml.dump(schema)


@app.function
def write_schema(name: str):
    """Generate the ``name`` schema and write it to disk.

    The file is written to ``marimo/_schemas/generated/<name>.yaml`` under the
    parent of this notebook's directory (the repository root when the notebook
    lives in ``scripts/``).

    Args:
        name: Schema family to write; passed through to ``generate_schema``.

    Raises:
        RuntimeError: If marimo cannot determine the notebook's directory.
        ValueError: Propagated from ``generate_schema`` for an unknown ``name``.
    """
    notebook_dir = mo.notebook_dir()
    if notebook_dir is None:
        # `mo.notebook_dir()` returns None when the location is unknown; fail
        # with a clear message instead of an AttributeError on `.parent`.
        raise RuntimeError(
            "Cannot determine the notebook directory; "
            "run this notebook from a file on disk."
        )
    output = (
        notebook_dir.parent
        / "marimo"
        / "_schemas"
        / "generated"
        / f"{name}.yaml"
    )
    # Render first so a failure does not leave an empty directory behind.
    content = generate_schema(name)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(content, encoding="utf-8")


@app.cell
def _(generate_schema_button):
    # Regenerate both schema files when this file is run as a script, or when
    # the "Write schema" button has been clicked in the editor.
    if mo.app_meta().mode == "script" or generate_schema_button.value:
        write_schema("session")
        write_schema("notebook")
    return


@app.function(hide_code=True)
def build_openapi_schema(messages, title: str = "marimo_session"):
    """Build an OpenAPI 3.0 document whose component schemas describe ``messages``.

    Args:
        messages: Types to convert, in order; each becomes an entry under
            ``components.schemas`` keyed by its ``__name__`` (or an override).
        title: Value for the document's ``info.title``. Defaults to
            ``"marimo_session"``.

    Returns:
        A plain dict suitable for serialization (e.g. with ``yaml.dump``).
    """
    # Only the converter is needed; the previously imported server router
    # (`build_routes`) was never used.
    from marimo._utils.dataclass_to_openapi import PythonTypeToOpenAPI

    # Type -> component name for types already converted; handed to the
    # converter on every call (presumably so it can emit references to them).
    processed_classes: dict[t.Any, str] = {}
    component_schemas: dict[str, t.Any] = {}
    name_overrides: dict[t.Any, str] = {}

    converter = PythonTypeToOpenAPI(
        camel_case=False, name_overrides=name_overrides
    )
    for cls in messages:
        # `cls` may already be recorded (e.g. seen while converting an earlier
        # type); drop it so this call produces its own full schema.
        if cls in processed_classes:
            del processed_classes[cls]
        name = name_overrides.get(cls, cls.__name__)  # type: ignore[attr-defined]
        component_schemas[name] = converter.convert(cls, processed_classes)
        processed_classes[cls] = name

    return {
        "openapi": "3.0.0",
        "info": {"title": title},
        "components": {"schemas": dict(component_schemas)},
    }


if __name__ == "__main__":
    # Running this file directly executes the notebook in "script" mode, which
    # makes the write cell regenerate both schema files.
    app.run()
